library(tidyverse)
library(knitr)
library(plotly) ; library(viridis) ; library(gridExtra) ; library(RColorBrewer) ; library(ggpubr)
library(mgcv)
library(caret) ; library(ROCR) ; library(car) ; library(MLmetrics)
library(knitr) ; library(kableExtra)
library(ROCR)
library(expss)

# Colours used to tell gene groups apart in the plots.
#
# r: index vector (e.g. c(1:3,8,7)) selecting entries of the fixed 10-colour
#    palette defined below.
# Returns a character vector with the selected colours, in the order given by r.
SFARI_colour_hue = function(r) {
  pal = c('#FF7631','#FFB100','#E8E328','#8CC83F','#62CCA6','#59B9C9','#b3b3b3','#808080','gray','#d9d9d9')
  # Return the selection as the value of the function; ending the body on an
  # assignment would return it invisibly, which hides the result when the
  # function is called on its own
  pal[r]
}
# Restore the preprocessed objects; datExpr (used throughout this script) is one
# of them and is coerced to a plain data.frame here
load('./../Data/preprocessed_data.RData')
datExpr = datExpr %>% data.frame

# Ridge Regression output
# NOTE(review): presumably this is where `predictions` and `dataset` (used below
# without being created in this script) come from -- confirm
load('./../Data/Ridge_model.RData')

# SFARI Genes
SFARI_genes = read_csv('./../../../SFARI/Data/SFARI_genes_01-03-2020_w_ensembl_IDs.csv')
# Keep the first row of each ID and drop the rows whose ID is missing
SFARI_genes = SFARI_genes[!duplicated(SFARI_genes$ID) & !is.na(SFARI_genes$ID),]


# GO Neuronal annotations: regex 'neuron' in GO functional annotations and label the genes that make a match as neuronal
GO_annotations = read.csv('./../Data/genes_GO_annotations.csv')
# One row per matching gene ID, flagged with Neuronal = 1
GO_neuronal = GO_annotations %>% filter(grepl('neuron', go_term)) %>% 
              mutate('ID'=as.character(ensembl_gene_id)) %>% 
              dplyr::select(-ensembl_gene_id) %>% distinct(ID) %>%
              mutate('Neuronal'=1)

# Add all this info to predictions
# gene.score keeps the SFARI `gene-score` where there is one; genes without a
# score are labelled 'Neuronal' when their ID is in GO_neuronal and 'Others' otherwise
predictions = predictions %>% left_join(SFARI_genes %>% dplyr::select(ID, `gene-score`), by = 'ID') %>%
              mutate(gene.score = ifelse(is.na(`gene-score`), 
                                         ifelse(ID %in% GO_neuronal$ID, 'Neuronal', 'Others'), 
                                         `gene-score`)) %>%
              dplyr::select(-`gene-score`)

# Module of each gene, taken from the column of clusters.csv named by clustering_selected
clustering_selected = 'DynamicHybrid'
clusterings = read_csv('./../Data/clusters.csv')
clusterings$Module = clusterings[,clustering_selected] %>% data.frame %>% unlist %>% unname
assigned_module = clusterings %>% dplyr::select(ID, Module)


# Remove the objects that are not used in the rest of the script
rm(rownames_dataset, GO_annotations, datGenes, dds, clustering_selected,
   clusterings)


Introduction


In 10_classification_model.html we trained a Ridge regression to assign a probability to each gene with the objective of identifying new candidate SFARI Genes based on their gene expression behaviour captured with the WGCNA pipeline

The model seems to perform well (performance metrics can be found in 10_classification_model.html), but we found a bias related to the level of expression of the genes, in general, with the model assigning higher probabilities to genes with higher levels of expression

This is a problem because we had previously discovered a bias in the SFARI scores related to mean level of expression, which means that this could be a confounding factor in our model and the reason why it seems to perform well

# Mean expression of each gene (one row of datExpr per gene ID). Only the mean
# is needed for this plot, so no per-row standard deviation is computed
mean_expr = data.frame(ID=rownames(datExpr), meanExpr=rowMeans(datExpr))

plot_data = predictions %>% left_join(mean_expr, by='ID')

# Probability assigned by the model against mean expression, with a loess trend
# line to make the dependence between the two visible
plot_data %>% ggplot(aes(meanExpr, prob)) + geom_point(alpha=0.1, color='#0099cc') + 
              geom_smooth(method='loess', color='gray', alpha=0.2) +
              xlab('Mean Expression') + ylab('Probability') + 
              ggtitle('Bias in model probabilities by level of expression') +
              theme_minimal()

rm(mean_expr)



Solutions to Bias Problem


This section is based on the paper Identifying and Correcting Label Bias in Machine Learning


Work in fair classification can be categorised into three approaches:



1. Post-processing Approach


After the model has been trained with the bias, perform a post-processing of the classifier outputs. This approach is quite simple to implement but has some downsides:

  • It has limited flexibility

  • Decoupling the training and calibration can lead to models with poor accuracy tradeoff (when training your model it may be focusing on the bias, in our case mean expression, and overlooking more important aspects of your data, such as biological significance)

Note: This is the approach we are going to try in this Markdown



2. Lagrangian Approach


Transforming the problem into a constrained optimisation problem (fairness as the constraint) using Lagrange multipliers.

Some of the downsides of this approach are:

  • The fairness constraints are often irregular and have to be relaxed in order to optimise

  • Training can be difficult, the Lagrangian may not even have a solution to converge to

  • Constrained optimisation can be inherently unstable

  • It can overfit and have poor fairness generalisation

  • According to the paper, it often yields poor trade-offs in fairness and accuracy

Note: It seems quite complicated and has many downsides, so I’m not going to implement this approach



3. Pre-processing Approach


These approaches primarily involve “massaging” the data to remove bias.

Some downsides are:

  • These approaches typically do not perform as well as the state-of-art and come with few theoretical guarantees

Note: In earlier versions of this code, I implemented this approach by trying to remove the level of expression signal from each feature of the dataset (since the Module Membership features capture the bias in an indirect way), but removing the mean expression signal modified the module membership of the genes in big ways sometimes and it didn’t seem to solve the problem in the end, so this proved not to be very useful and wasn’t implemented in this final version



New Method proposed by the paper (weighting technique)


They introduce a new mathematical framework for fairness in which we assume that there exists an unknown but unbiased ground truth label function and that the labels observed in the data are assigned by an agent who is possibly biased, but otherwise has the intention of being accurate

Assigning appropriate weights to each sample in the training data and iteratively training a classifier with the new weighted samples leads to an unbiased classifier on the original un-weighted dataset that simultaneously minimises the weighted loss and maximises fairness

Advantages:

  • This approach works also on settings where both the features and the labels are biased

  • It can be used with many ML algorithms

  • It can be applied to many notions of fairness

  • It doesn’t have strict assumptions about the behaviour of the data or the labels

  • According to the paper, it’s fast and robust

  • According to the paper, it consistently leads to fairer classifiers, as well as a better or comparative predictive error than the other methods


Also, this is not important, but I thought it was interesting: Since the algorithm simultaneously minimises the weighted loss and maximises fairness via learning the coefficients, it may be interpreted as competing goals with different objective functions; thus, it’s a form of a non-zero-sum two-player game

Note: Implemented in 14_bias_correciton_weighting_technique.html






Post Processing Approach Implementation



After the model has been trained with the bias, perform a post-processing of the classifier outputs

Since the effect of the bias is proportional to the mean level of expression of a gene, we can correct it by removing the effect of the mean expression from the probability of the model

Problems:


Remove Bias


The relation between level of expression and probability assigned by the model could be modelled by a linear regression, but we would lose some of the behaviour. Fitting a curve using Generalised Additive Models seems to capture the relation in a much better way, with an \(R^2\) twice as large and no recognisable pattern in the residuals of the regression

test_set = predictions
old_predictions = predictions

# Attach the mean expression of each gene to its prediction. The join starts
# from test_set so the rows of plot_data -- and therefore the residuals of the
# models fitted on it -- stay in the same order as the rows of test_set. This
# matters because the GAM residuals are later assigned back to test_set by position
plot_data = test_set %>% 
            left_join(data.frame('ID'=rownames(datExpr), 'meanExpr'=rowMeans(datExpr)), by='ID')

# Fit linear and GAM models to data
lm_fit = lm(prob ~ meanExpr, data = plot_data)
gam_fit = gam(prob ~ s(meanExpr), method = 'REML', data = plot_data)

# Keep the residuals of both fits next to the data they were computed from
plot_data = plot_data %>% mutate(lm_res = lm_fit$residuals, gam_res = gam_fit$residuals)

# Plot data
# Top row: probability vs mean expression with each type of fit
p1 = plot_data %>% ggplot(aes(meanExpr, prob)) + geom_point(alpha=0.1, color='#0099cc') + geom_smooth(method='lm', color='gray', alpha=0.3) +
     xlab('Mean Expression') + ylab('Probability') + ggtitle('Linear Fit') + theme_minimal()

p2 = plot_data %>% ggplot(aes(meanExpr, prob)) + geom_point(alpha=0.1, color='#0099cc') + geom_smooth(method='gam', color='gray', alpha=0.3) +
     xlab('Mean Expression') + ylab('Probability') + ggtitle('GAM fit') + theme_minimal()

# Bottom row: residuals of each fit vs mean expression, titled with the R^2
# reported by the summary of the corresponding model
p3 = plot_data %>% ggplot(aes(meanExpr, lm_res)) + geom_point(alpha=0.1, color='#ff9900') + 
     geom_smooth(method='gam', color='gray', alpha=0.3) + xlab('Mean Expression') +
     ylab('Residuals') + theme_minimal() + ggtitle(bquote(paste(R^{2},' = ', .(round(summary(lm_fit)$r.squared, 4)))))

p4 = plot_data %>% ggplot(aes(meanExpr, gam_res)) + geom_point(alpha=0.1, color='#ff9900') + 
     geom_smooth(method='gam', color='gray', alpha=0.3) + xlab('Mean Expression') +
     ylab('Residuals') + theme_minimal() + ggtitle(bquote(paste(R^{2},' = ', .(round(summary(gam_fit)$r.sq, 4)))))

grid.arrange(p1, p2, p3, p4, nrow = 2)

# gam_fit is kept on purpose: its residuals are used afterwards to correct the scores
rm(p1, p2, p3, p4, lm_fit)


Remove bias from scores with GAM fit


  • Assigning the residuals of the GAM model as the new model probability

  • Adding the mean probability of the original model to each new probability so our new probabilities have the same mean as the original ones

  • As with the plot above, the relation between mean expression and the probability assigned by the model is gone

# Corrected score = GAM residual shifted by the mean of the original
# probabilities, so the corrected scores keep the original mean
test_set[['corrected_score']] = mean(test_set$prob) + gam_fit$residuals

# Mean expression next to the corrected score of each gene, for plotting
plot_data = right_join(data.frame('ID'=rownames(datExpr), 'meanExpr'=rowMeans(datExpr)),
                       test_set, by='ID')

# Corrected score vs mean expression with a GAM trend line
ggplot(plot_data, aes(meanExpr, corrected_score)) +
  geom_point(alpha=0.1, color='#0099cc') +
  geom_smooth(method='gam', color='gray', alpha=0.3) +
  labs(y='Corrected Score', x='Mean Expression',
       title='Mean expression vs Model score corrected using GAM') +
  theme_minimal()

rm(gam_fit)


We could use these corrected scores directly to study the performance of the bias-corrected model, but we wouldn’t have the standard deviation of the performance metrics as we had in 10_classification_model.html where we ran the model several times. To have them here as well, I’m going to run the model many times, correcting the bias in each run


Ridge Regression with Post Processing Bias Correction


Notes:

  • Running the model multiple times to get more accurate measurements of its performance

  • Over-sampling positive samples in the training set to obtain a 1:1 class ratio using SMOTE

  • Performing 5 repetitions of cross validation with 10 folds each

  • Correcting the mean expression bias in each run using the post-processing approach

### DEFINE FUNCTIONS

# Split the global `dataset` into a train and a test set.
#
# p:    proportion of the samples assigned to the train set
# seed: random seed set right before the partition is drawn
#
# The partition is drawn on the gene.score of each sample (taken from the global
# `original_dataset`, with missing scores labelled 'None') so that both sets are
# balanced for each score.
# Returns a list with the elements 'train_set' and 'test_set'.
create_train_test_sets = function(p, seed){

  # Score of every sample, in the same row order as dataset
  known_scores = original_dataset %>% dplyr::select(ID, gene.score)
  sample_scores = dataset %>% mutate(ID = rownames(.)) %>% dplyr::select(ID) %>%
                  left_join(known_scores, by = 'ID')
  sample_scores = sample_scores %>%
                  mutate(gene.score = ifelse(is.na(gene.score), 'None', gene.score))

  # Draw the row indices of the train set
  set.seed(seed)
  in_train = createDataPartition(sample_scores$gene.score, p = p, list = FALSE)

  list('train_set' = dataset[in_train,], 'test_set' = dataset[-in_train,])
}



# Train the Ridge regression on one train/test split, correct the mean
# expression bias of its test-set probabilities and measure its performance.
#
# p:    proportion of the samples assigned to the train set
# seed: random seed used for the split and for the cross-validation
#
# Reads the globals `dataset` and `original_dataset` (through
# create_train_test_sets) and `datExpr`.
#
# Returns a list with:
#   acc, prec, rec, F1 : metrics of the bias-corrected class predictions on the test set
#   AUC                : area under the ROC curve of the bias-corrected probabilities
#   preds              : data frame with ID, prob, pred, corrected_prob and corrected_pred
#                        of every test sample
#   coefs              : coefficients of the final model at the selected lambda
run_model = function(p, seed){
  
  # Create train and test sets
  train_test_sets = create_train_test_sets(p, seed)
  train_set = train_test_sets[['train_set']]
  test_set = train_test_sets[['test_set']]
  
  # Train Model
  # glmnet with alpha = 0, lambda tuned by repeated cross-validation on ROC,
  # using SMOTE sampling inside each resample
  train_set = train_set %>% mutate(SFARI = ifelse(SFARI==TRUE, 'SFARI', 'not_SFARI') %>% as.factor)
  lambda_seq = 10^seq(1, -4, by = -.1)
  set.seed(seed)
  k_fold = 10
  cv_repeats = 5
  smote_over_sampling = trainControl(method = 'repeatedcv', number = k_fold, repeats = cv_repeats,
                                     verboseIter = FALSE, classProbs = TRUE, savePredictions = 'final', 
                                     summaryFunction = twoClassSummary, sampling = 'smote')
  fit = train(SFARI ~., data = train_set, method = 'glmnet', trControl = smote_over_sampling, metric = 'ROC',
              tuneGrid = expand.grid(alpha = 0, lambda = lambda_seq))
  
  # Predict labels in test set
  predictions = fit %>% predict(test_set, type = 'prob')
  preds = data.frame('ID' = rownames(test_set), 'prob' = predictions$SFARI) %>% mutate(pred = prob>0.5)
  
  
  #############################################################################################################
  # Correct Mean Expression Bias in predictions
  # The join starts from preds so bias_data keeps the row order of preds: the
  # residuals of the GAM are assigned back to preds by position
  bias_data = preds %>% 
              left_join(data.frame('ID'=rownames(datExpr), 'meanExpr'=rowMeans(datExpr)), by='ID')
  gam_fit = gam(prob ~ s(meanExpr), method = 'REML', data = bias_data)
  preds$corrected_prob = gam_fit$residuals + mean(preds$prob)
  # The corrected label has to come from the corrected probability; thresholding
  # the original one would make the metrics below describe the uncorrected model
  preds$corrected_pred = preds$corrected_prob>0.5
  #############################################################################################################
  

  # Measure performance of the model
  acc = mean(test_set$SFARI==preds$corrected_pred)
  prec = Precision(test_set$SFARI %>% as.numeric, preds$corrected_pred %>% as.numeric, positive = '1')
  rec = Recall(test_set$SFARI %>% as.numeric, preds$corrected_pred %>% as.numeric, positive = '1')
  F1 = F1_Score(test_set$SFARI %>% as.numeric, preds$corrected_pred %>% as.numeric, positive = '1')
  pred_ROCR = prediction(preds$corrected_prob, test_set$SFARI)
  AUC = performance(pred_ROCR, measure='auc')@y.values[[1]]
  
  # Extract coefficients from features
  coefs = coef(fit$finalModel, fit$bestTune$lambda) %>% as.vector
  
  return(list('acc' = acc, 'prec' = prec, 'rec' = rec, 'F1' = F1, 
              'AUC' = AUC, 'preds' = preds, 'coefs' = coefs))
}


### RUN MODEL

# Parameters
p = 0.75        # proportion of the samples assigned to the train set in each run
n_iter = 25     # number of runs of the model
# One seed per run, starting at 123 (built from seq_len so that n_iter = 0
# gives no seeds instead of counting backwards)
seeds = 123 + seq_len(n_iter) - 1

# So the input is the same as in 10_classification_model.html
# The join key is stated explicitly so the match is always made on the gene ID
original_dataset = dataset %>% mutate(ID = rownames(.)) %>% 
                   left_join(old_predictions %>% dplyr::select(ID, gene.score), by = 'ID')

# Store outputs
# Performance metrics: one value is appended per run
acc = c()
prec = c()
rec = c()
F1 = c()
AUC = c()
# Per-gene accumulators: sums over the runs in which the gene was in the test
# set, with n counting those runs
predictions = data.frame('ID' = rownames(dataset), 'SFARI' = dataset$SFARI, 'prob' = 0, 'pred' = 0,
                         'corrected_prob' = 0, 'corrected_pred' = 0, 'n' = 0)
# Running sum of the model coefficients over the runs (averaged after the loop)
coefs = data.frame('var' = c('Intercept', colnames(dataset[,-ncol(dataset)])), 'coef' = 0)

for(seed in seeds){
  
  # Run model
  model_output = run_model(p, seed)
  
  # Update outputs
  acc = c(acc, model_output[['acc']])
  prec = c(prec, model_output[['prec']])
  rec = c(rec, model_output[['rec']])
  F1 = c(F1, model_output[['F1']])
  AUC = c(AUC, model_output[['AUC']])
  preds = model_output[['preds']]
  coefs$coef = coefs$coef + model_output[['coefs']]
  # Add this run's values to the rows of the genes that were in its test set.
  # The addition is positional: it relies on preds listing those genes in the
  # same order as predictions, and on update_preds having its columns in the
  # order prob, pred, corrected_prob, corrected_pred, n
  update_preds = preds %>% dplyr::select(-ID) %>% mutate(n=1)
  predictions[predictions$ID %in% preds$ID, c('prob','pred','corrected_prob','corrected_pred','n')] = 
    predictions[predictions$ID %in% preds$ID, c('prob','pred','corrected_prob','corrected_pred','n')] +
     update_preds
}

# Turn the accumulated sums into averages
coefs = coefs %>% mutate(coef = coef/n_iter)
# The summed class predictions are kept as *_count columns and the final class
# is obtained by thresholding the averaged probability. Genes that were never in
# a test set have n = 0, so their averages are 0/0
predictions = predictions %>% mutate(prob = prob/n, pred_count = pred, pred = prob>0.5,
                                     corrected_prob = corrected_prob/n, corrected_pred_count = corrected_pred, 
                                     corrected_pred = corrected_prob>0.5)


rm(p, seeds, update_preds, create_train_test_sets, run_model)
# Keep the genes that were in a test set at least once and attach their GS and
# MTcor values; the gene IDs are also stored as row names
test_set = predictions %>% filter(n>0) %>% 
           left_join(dataset %>% mutate(ID = rownames(.)) %>% dplyr::select(ID, GS, MTcor), by = 'ID')
rownames(test_set) = predictions$ID[predictions$n>0]


Performance metrics


Confusion matrix

# Label the columns shown in the table, then cross actual labels (rows) with the
# corrected predictions plus a total column
conf_mat = apply_labels(test_set,
                        SFARI = 'Actual Labels',
                        corrected_prob = 'Assigned Probability',
                        corrected_pred = 'Label Prediction')

cro(conf_mat$SFARI, list(conf_mat$corrected_pred, total()))
 Label Prediction     #Total 
 FALSE   TRUE   
 Actual Labels 
   FALSE  12177 3025   15202
   TRUE  475 305   780
   #Total cases  12652 3330   15982
rm(conf_mat)


Accuracy: Mean = 0.7772 SD = 0.014


Precision: Mean = 0.1077 SD = 0.0093


Recall: Mean = 0.487 SD = 0.0337


F1 score: Mean = 0.1763 SD = 0.0139


ROC Curve: Mean = 0.6478 SD = 0.0165

# ROCR prediction object built from the corrected probabilities and the actual labels
pred_ROCR = prediction(test_set$corrected_prob, test_set$SFARI)

# True positive rate against false positive rate, and the area under that curve
roc_ROCR = performance(pred_ROCR, measure='tpr', x.measure='fpr')
auc = slot(performance(pred_ROCR, measure='auc'), 'y.values')[[1]]

# ROC curve with the diagonal as reference
plot(roc_ROCR, main=paste0('ROC curve (AUC=', round(auc,2), ')'), col='#009999')
abline(a=0, b=1, col='#666666')


Lift Curve

# Lift against rate of positive predictions
lift_ROCR = performance(pred_ROCR, measure='lift', x.measure='rpp')
plot(lift_ROCR, main='Lift curve', col='#86b300')

# Drop the ROCR objects and the per-run AUC values
rm(list = c('pred_ROCR', 'roc_ROCR', 'AUC', 'lift_ROCR'))




Analyse Results


Score distribution by SFARI Label


SFARI genes have a higher score distribution than the rest, but the overlap is larger than before

plot_data = test_set %>% dplyr::select(corrected_prob, SFARI)

# Density of the corrected score for each SFARI label; the dashed vertical lines
# mark the mean corrected score of the SFARI and of the non-SFARI genes
(plot_data %>% ggplot(aes(corrected_prob, fill=SFARI, color=SFARI)) +
   geom_density(alpha=0.3) +
   geom_vline(xintercept = with(plot_data, mean(corrected_prob[SFARI])),
              color = '#00C0C2', linetype='dashed') +
   geom_vline(xintercept = with(plot_data, mean(corrected_prob[!SFARI])),
              color = '#FF7371', linetype='dashed') +
   labs(x='Score', title='Model score distribution by SFARI Label') +
   theme_minimal()) %>% ggplotly


Score distribution by SFARI Gene Scores


The relation between probability and SFARI Gene Scores weakened but it’s still there

# Corrected probability of each test gene next to its gene.score (looked up in
# original_dataset by gene ID)
plot_data = test_set %>% transmute(ID=rownames(test_set), corrected_prob) %>%
            left_join(original_dataset, by='ID')
plot_data = plot_data %>% dplyr::select(ID, corrected_prob, gene.score) %>%
            apply_labels(gene.score='SFARI Gene score')

# Number of genes in each gene.score group
cro(plot_data$gene.score)
 #Total 
 SFARI Gene score 
   1  144
   2  205
   3  430
   Neuronal  927
   Others  14261
   #Total cases  15967
mean_vals = plot_data %>% group_by(gene.score) %>% summarise(mean_prob = mean(corrected_prob))

# Pairs of groups to compare, listed from the closest pairs to the furthest ones
comparisons = list(c('1','2'), c('2','3'), c('3','Neuronal'), c('Neuronal','Others'),
                   c('1','3'), c('3','Others'), c('2','Neuronal'),
                   c('1','Neuronal'), c('2','Others'), c('1','Others'))
# Height of each comparison bracket: the first four share the base height, the
# next two sit one step above it and the rest climb one step each
increase = 0.08
base = 0.9
pos_y_comparisons = base + increase*c(0, 0, 0, 0, 1, 1, 2, 3, 4, 5)

# Corrected probability by gene.score group, with t-tests (unequal variances)
# between the pairs of groups listed above
plot_data %>% filter(!is.na(gene.score)) %>%
  ggplot(aes(gene.score, corrected_prob, fill=gene.score)) +
  geom_boxplot(outlier.colour='#cccccc', outlier.shape='o', outlier.size=3) +
  stat_compare_means(comparisons = comparisons, label = 'p.signif', method = 't.test',
                     method.args = list(var.equal = FALSE), label.y = pos_y_comparisons,
                     tip.length = .02) +
  scale_fill_manual(values=SFARI_colour_hue(r=c(1:3,8,7))) +
  labs(title='Distribution of probabilities by SFARI score', x='SFARI score', y='Probability') +
  theme_minimal() + theme(legend.position = 'none')

rm(mean_vals, increase, base, pos_y_comparisons)


Genes with the highest Probabilities


  • The concentration of SFARI genes decreases from 1:3 to 1:5

  • The genes with the highest probabilities are no longer SFARI Genes

# Table of the 50 test genes with the highest corrected probability, together
# with their symbol, gene significance, module, module-diagnosis correlation and
# gene.score
test_set %>% mutate(ID = rownames(test_set)) %>%
             dplyr::select(corrected_prob, SFARI, ID) %>%
             arrange(desc(corrected_prob)) %>% top_n(50, wt=corrected_prob) %>%
             left_join(old_predictions %>% dplyr::select(ID, gene.score, external_gene_id, MTcor, GS),
                       by = 'ID') %>%
             # Round the numeric columns, then give them the names shown in the table
             mutate(MTcor = round(MTcor,4), corrected_prob = round(corrected_prob,4),
                    GS = round(GS,4)) %>%
             dplyr::rename('GeneSymbol' = external_gene_id, 'Probability' = corrected_prob,
                           'ModuleDiagnosis_corr' = MTcor, 'GeneSignificance' = GS) %>%
             left_join(assigned_module, by = 'ID') %>%
             dplyr::select(GeneSymbol, GeneSignificance, ModuleDiagnosis_corr, Module, Probability,
                           gene.score) %>%
             kable(caption = 'Genes with highest model probabilities from the test set') %>%
             kable_styling(full_width = FALSE)
Genes with highest model probabilities from the test set
GeneSymbol GeneSignificance ModuleDiagnosis_corr Module Probability gene.score
FMNL1 -0.2229 -0.3847 #00B7E7 0.8760 Others
HUNK -0.3276 -0.3731 #4BB400 0.8537 Others
AMPD3 -0.2812 -0.9003 #00BADE 0.8372 Others
POM121C -0.2811 -0.5573 #00C0BF 0.8363 Others
RPRD2 -0.0931 0.3428 #00C08C 0.8341 Others
ETV6 -0.1439 -0.6508 #E76BF3 0.8309 Others
SAP130 -0.2497 -0.5573 #00C0BF 0.8245 Others
FOXJ2 0.2507 0.3428 #00C08C 0.8201 Others
TLE3 0.4865 0.5341 #EF67EB 0.8131 Others
MIDN -0.0560 -0.5573 #00C0BF 0.8103 Others
PCDHB13 -0.0405 -0.4661 #1FB700 0.8086 Others
CCRN4L -0.0277 0.5341 #EF67EB 0.8061 Others
ITGAX -0.0070 -0.0127 #00BF7D 0.8055 Others
FAM193A 0.0348 0.0721 #96A900 0.8053 Others
SCAF4 0.0161 -0.5573 #00C0BF 0.8021 Others
TBC1D8 0.0962 0.7206 #FF66A9 0.7994 Others
DROSHA -0.4103 -0.5573 #00C0BF 0.7979 Others
MBD5 0.1916 0.0721 #96A900 0.7961 1
SETD1B -0.1577 -0.5573 #00C0BF 0.7934 3
PLAGL2 0.3913 0.5341 #EF67EB 0.7933 Others
BCL9 0.1303 -0.0245 #FE6D8D 0.7928 Others
SAMSN1 0.4237 0.7206 #FF66A9 0.7925 Others
ZNF804A -0.2756 -0.3731 #4BB400 0.7904 2
C20orf112 0.0294 0.3428 #00C08C 0.7871 Others
RNF111 -0.2393 0.0721 #96A900 0.7853 Others
BDNF 0.3529 0.5341 #EF67EB 0.7844 Neuronal
EP300 -0.0235 0.3428 #00C08C 0.7830 1
TCF7L2 0.1314 -0.5573 #00C0BF 0.7830 2
AGO2 -0.1473 -0.5573 #00C0BF 0.7829 Others
IGF1 -0.2510 -0.3731 #4BB400 0.7825 Others
AQP1 -0.0931 -0.9003 #00BADE 0.7814 Others
NUP214 0.2331 0.0721 #96A900 0.7809 Others
PARVG 0.5663 0.7206 #FF66A9 0.7805 Others
CACNG3 -0.4698 -0.6770 #00A6FF 0.7803 Others
ARID1B 0.2675 0.3428 #00C08C 0.7793 1
RFX7 0.1381 0.0721 #96A900 0.7785 Others
SNX25 0.0969 -0.0127 #00BF7D 0.7776 Others
BCOR 0.2903 0.7651 #77AF00 0.7773 Others
PIEZO2 -0.0776 -0.6508 #E76BF3 0.7763 Others
CHAMP1 -0.4095 -0.5573 #00C0BF 0.7756 1
ZNF275 -0.0267 -0.5573 #00C0BF 0.7744 Others
CNOT3 0.1375 0.0721 #96A900 0.7719 1
CECR6 -0.2834 -0.5573 #00C0BF 0.7716 Others
ZNF592 0.2727 0.3428 #00C08C 0.7706 Others
PHF12 -0.1795 -0.0245 #FE6D8D 0.7695 Others
ZFPM2 0.2347 0.3428 #00C08C 0.7693 Others
BAZ2A 0.0518 0.0721 #96A900 0.7682 Others
TET3 0.4848 0.5341 #EF67EB 0.7678 Others
FOXP4 -0.2121 -0.5573 #00C0BF 0.7675 Others
PLXNC1 -0.0093 0.5341 #EF67EB 0.7666 Others





Negative samples distribution


The objective of this model is to identify candidate SFARI genes. For this, we are going to focus on the negative samples (the non-SFARI genes)

# Genes of the test set that are not SFARI genes
negative_set = filter(test_set, !SFARI)

# Labelled copy used only for printing the table of corrected predictions
negative_set_table = apply_labels(negative_set,
                                  corrected_prob = 'Assigned Probability',
                                  corrected_pred = 'Label Prediction')

cro(negative_set_table$corrected_pred)
 #Total 
 Label Prediction 
   FALSE  12177
   TRUE  3025
   #Total cases  15202

3025 genes are predicted as ASD-related


# Density of the corrected probability of the negative samples; the dotted line
# marks the 0.5 classification threshold
ggplot(negative_set, aes(corrected_prob)) +
  geom_density(color='#F8766D', fill='#F8766D', alpha=0.5) +
  geom_vline(xintercept=0.5, color='#333333', linetype='dotted') +
  labs(x='Probability', title='Probability distribution of the Negative samples in the Test Set') +
  theme_minimal()


Comparison with the original model’s probabilities:

  • The genes with the highest scores were affected the most as a group

  • More genes got their score increased than decreased but on average, the ones that got it decreased had a bigger change

# Original vs corrected probability of each negative sample, coloured by the
# size of the change; the dashed diagonal marks "no change"
negative_set %>% mutate(diff = abs(prob-corrected_prob)) %>%
  ggplot(aes(prob, corrected_prob, color = diff)) +
  geom_point(alpha=0.2) + scale_color_viridis() +
  geom_abline(slope=1, intercept=0, color='gray', linetype='dashed') +
  geom_smooth(color='#666666', alpha=0.5, se=TRUE, size=0.5) + coord_fixed() +
  labs(x='Original probability', y='Corrected probability') +
  theme_minimal() + theme(legend.position = 'none')

# Original class prediction (rows) against corrected class prediction (columns)
negative_set_table = apply_labels(negative_set,
                                  corrected_prob = 'Corrected Probability',
                                  corrected_pred = 'Corrected Class Prediction',
                                  pred = 'Original Class Prediction')

cro(negative_set_table$pred, list(negative_set_table$corrected_pred, total()))
 Corrected Class Prediction     #Total 
 FALSE   TRUE   
 Original Class Prediction 
   FALSE  11668 453   12121
   TRUE  509 2572   3081
   #Total cases  12177 3025   15202

94% of the genes maintained their original predicted class

rm(negative_set_table)

Probability and Gene Significance


The relation is not as strong as before in the highest scores

*The transparent version of the trend line is the original trend line

# Gene Significance against the corrected score, coloured by MTcor. The solid
# trend is fitted on the corrected score, the semi-transparent one on the
# original probability; the dashed line is the mean GS of the negative samples
ggplot(negative_set, aes(corrected_prob, GS, color=MTcor)) +
  geom_point() +
  geom_smooth(method='gam', color='#666666') +
  geom_line(stat='smooth', method='gam', color='#666666', alpha=0.5, size=1.2, aes(x=prob)) +
  geom_hline(yintercept=mean(negative_set$GS), color='gray', linetype='dashed') +
  scale_color_gradientn(colours=c('#F8766D','white','#00BFC4')) +
  labs(y='Gene Significance', x='Corrected Score',
       title='Relation between the Model\'s Corrected Score and Gene Significance') +
  theme_minimal()

Summarised version of score vs mean expression, plotting by module instead of by gene

The difference in the trend lines between this plot and the one above is that the one above takes all the points into consideration while this considers each module as an observation by itself, so the top one is strongly affected by big modules and the bottom one treats all modules the same

The transparent version of each point and trend lines are the original values and trends before the bias correction

# One row per module: mean and sd of the original and of the corrected
# probability of its genes, number of genes and sign of its MTcor
plot_data = negative_set %>% mutate(ID = rownames(.)) %>% left_join(assigned_module, by = 'ID') %>%
            group_by(MTcor, Module) %>%
            summarise(mean = mean(prob), sd = sd(prob),
                      new_mean = mean(corrected_prob), new_sd = sd(corrected_prob),
                      n = n()) %>%
            mutate(MTcor_sign = ifelse(MTcor>0, 'Positive', 'Negative')) %>%
            dplyr::select(Module, MTcor, MTcor_sign, mean, new_mean, sd, new_sd, n) %>% distinct()
# The module name becomes the ID of each point
names(plot_data)[1] = 'ID'

# Solid points/lines use the corrected means, semi-transparent ones the original means
(plot_data %>% ggplot(aes(MTcor, new_mean, size=n, color=MTcor_sign)) + geom_point(aes(id = ID)) +
   geom_smooth(method='loess', color='gray', se=FALSE) + geom_smooth(method='lm', se=FALSE) +
   geom_point(aes(y=mean), alpha=0.3) +
   geom_line(stat='smooth', method='loess', color='gray', se=FALSE, alpha=0.3, size=1.2, aes(y=mean)) +
   geom_line(stat='smooth', method='lm', se=FALSE, alpha=0.3, size=1.2, aes(y=mean)) +
   labs(x='Module-Diagnosis correlation', y='Mean Corrected Score by Module') +
   theme_minimal() + theme(legend.position='none')) %>% ggplotly


Probability and mean level of expression


Check if correcting by gene also corrected by module: Yes, but not enough to remove the bias completely

# Mean and standard deviation of the expression of each gene
mean_and_sd = data.frame(ID=rownames(datExpr), meanExpr=rowMeans(datExpr), sdExpr=apply(datExpr,1,sd))

# Negative samples with their expression summary and their module
plot_data = negative_set %>% mutate(ID=rownames(test_set)[!test_set$SFARI])
plot_data = plot_data %>% left_join(mean_and_sd, by='ID') %>% left_join(assigned_module, by='ID')

# One row per module: mean expression, mean original and corrected probability, size
plot_data2 = plot_data %>% group_by(Module) %>%
             summarise(meanExpr = mean(meanExpr), meanProb = mean(prob),
                       new_meanProb = mean(corrected_prob), n=n())

# Solid points/trend use the corrected means, semi-transparent ones the original
# means; each point is drawn in the colour stored in Module
(plot_data2 %>% ggplot(aes(meanExpr, new_meanProb, size=n)) +
   geom_point(color=plot_data2$Module) + geom_point(color=plot_data2$Module, alpha=0.3, aes(y=meanProb)) +
   geom_smooth(method='loess', se=TRUE, color='gray', alpha=0.1, size=0.7) +
   geom_line(stat='smooth', method='loess', se=TRUE, color='gray', alpha=0.4, size=1.2, aes(y=meanProb)) +
   labs(x='Mean Expression', y='Corrected Probability',
        title='Mean expression vs corrected Model score by Module') +
   theme_minimal() + theme(legend.position='none')) %>% ggplotly
rm(plot_data2, mean_and_sd)


Probability and LFC


The relation seems to have gotten a bit stronger for the over-expressed genes and a bit weaker for the under-expressed genes

# Negative samples with their differential expression information (matched by gene ID)
plot_data = negative_set %>% mutate(ID=rownames(test_set)[!test_set$SFARI]) %>%
            left_join(DE_info %>% data.frame %>% mutate(ID=rownames(.)), by='ID')

# All genes together: the solid trend uses the corrected probability, the
# semi-transparent one the original probability
ggplot(plot_data, aes(log2FoldChange, corrected_prob)) +
  geom_point(alpha=0.1, color='#0099cc') +
  geom_smooth(method='loess', color='gray', alpha=0.1) +
  geom_line(stat='smooth', method='loess', color='gray', alpha=0.4, size=1.5, aes(y=prob)) +
  labs(x='LFC', y='Corrected Probability') +
  theme_minimal() + ggtitle('LFC vs model probability by gene')

# Left panel: genes with negative LFC, split by differential expression (padj<0.05).
# Both panels share the same y range so they can be read side by side
p1 = plot_data %>% filter(log2FoldChange<0) %>% mutate(DE = padj<0.05) %>%
     ggplot(aes(log2FoldChange, corrected_prob, color=DE)) + geom_point(alpha=0.1) +
     geom_smooth(method='loess', alpha=0.1) + labs(x='', y='Corrected Probability') +
     ylim(range(plot_data$corrected_prob)) +
     geom_line(stat='smooth', method='loess', alpha=0.4, size=1.5, aes(y=prob, color = DE)) +
     theme_minimal() + theme(legend.position = 'none', plot.margin=unit(c(1,-0.3,1,1), 'cm'))

# Right panel: genes with non-negative LFC; no axis titles, y axis drawn on the right
p2 = plot_data %>% filter(log2FoldChange>=0) %>% mutate(DE = padj<0.05) %>%
     ggplot(aes(log2FoldChange, corrected_prob, color=DE)) + geom_point(alpha=0.1) +
     geom_smooth(method='loess', alpha=0.1) + labs(x='', y='') +
     scale_y_continuous(position = 'right', limits = range(plot_data$corrected_prob)) +
     geom_line(stat='smooth', method = 'loess', alpha=0.4, size=1.5, aes(y = prob, color = DE)) +
     theme_minimal() + theme(plot.margin = unit(c(1,1,1,-0.3), 'cm'), axis.ticks.y = element_blank())

grid.arrange(p1, p2, nrow=1, top = 'LFC vs model probability by gene', bottom = 'LFC')

rm(p1, p2)


Probability and Module-Diagnosis correlation


Not much change

# Rank of each distinct MTcor value in dataset, from lowest to highest
mtcor_order = data.frame(MTcor=unique(dataset$MTcor)) %>% arrange(by=MTcor) %>%
              mutate(order=1:length(unique(dataset$MTcor)))

# Negative samples with their module, both probabilities and the rank of their MTcor
module_score = negative_set %>% mutate(ID=rownames(test_set)[!test_set$SFARI]) %>%
               left_join(old_predictions %>% dplyr::select(ID, gene.score), by='ID') %>%
               left_join(assigned_module, by = 'ID') %>%
               dplyr::select(ID, prob, corrected_prob, Module, MTcor) %>%
               left_join(mtcor_order, by='MTcor')
rm(mtcor_order)

# Corrected score against MTcor, each gene drawn in the colour stored in Module.
# The semi-transparent trend uses the original probability; the dotted line is
# the mean corrected score
(module_score %>% ggplot(aes(MTcor, corrected_prob)) +
   geom_point(color=module_score$Module, aes(id=ID, alpha=corrected_prob^4)) +
   geom_hline(yintercept=mean(module_score$corrected_prob), color='gray', linetype='dotted') +
   geom_line(stat='smooth', method = 'loess', color='gray', alpha=0.5, size=1.5, aes(x=MTcor, y=prob)) +
   geom_smooth(color='gray', method = 'loess', se = FALSE, alpha=0.3) + theme_minimal() +
   labs(x='Module-Diagnosis correlation', y='Corrected Score')) %>% ggplotly



Conclusion


This bias correction seems to be working partially but not entirely, it doesn’t make a big change in the performance of the model, but we may be losing a bit of biological signal on the way (LFC), mainly for under-expressed genes


Saving results

write.csv(test_set, file='./../Data/RM_post_proc_bias_correction.csv', row.names = TRUE)




Session info

# Print the R version, platform, locale and package versions used to produce this report
sessionInfo()
## R version 3.6.3 (2020-02-29)
## Platform: x86_64-pc-linux-gnu (64-bit)
## Running under: Ubuntu 18.04.4 LTS
## 
## Matrix products: default
## BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.7.1
## LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.7.1
## 
## locale:
##  [1] LC_CTYPE=en_GB.UTF-8       LC_NUMERIC=C              
##  [3] LC_TIME=en_GB.UTF-8        LC_COLLATE=en_GB.UTF-8    
##  [5] LC_MONETARY=en_GB.UTF-8    LC_MESSAGES=en_GB.UTF-8   
##  [7] LC_PAPER=en_GB.UTF-8       LC_NAME=C                 
##  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
## [11] LC_MEASUREMENT=en_GB.UTF-8 LC_IDENTIFICATION=C       
## 
## attached base packages:
## [1] grid      stats     graphics  grDevices utils     datasets  methods  
## [8] base     
## 
## other attached packages:
##  [1] DMwR_0.4.1         expss_0.10.2       kableExtra_1.1.0   MLmetrics_1.1.1   
##  [5] car_3.0-7          carData_3.0-3      ROCR_1.0-7         gplots_3.0.3      
##  [9] caret_6.0-86       lattice_0.20-41    mgcv_1.8-31        nlme_3.1-147      
## [13] ggpubr_0.2.5       magrittr_1.5       RColorBrewer_1.1-2 gridExtra_2.3     
## [17] viridis_0.5.1      viridisLite_0.3.0  plotly_4.9.2       knitr_1.28        
## [21] forcats_0.5.0      stringr_1.4.0      dplyr_1.0.0        purrr_0.3.4       
## [25] readr_1.3.1        tidyr_1.1.0        tibble_3.0.1       ggplot2_3.3.2     
## [29] tidyverse_1.3.0   
## 
## loaded via a namespace (and not attached):
##   [1] readxl_1.3.1                backports_1.1.8            
##   [3] Hmisc_4.4-0                 plyr_1.8.6                 
##   [5] lazyeval_0.2.2              splines_3.6.3              
##   [7] crosstalk_1.1.0.1           BiocParallel_1.18.1        
##   [9] GenomeInfoDb_1.20.0         digest_0.6.25              
##  [11] foreach_1.5.0               htmltools_0.4.0            
##  [13] gdata_2.18.0                fansi_0.4.1                
##  [15] memoise_1.1.0               checkmate_2.0.0            
##  [17] cluster_2.1.0               openxlsx_4.1.4             
##  [19] annotate_1.62.0             recipes_0.1.10             
##  [21] modelr_0.1.6                gower_0.2.1                
##  [23] matrixStats_0.56.0          xts_0.12-0                 
##  [25] jpeg_0.1-8.1                colorspace_1.4-1           
##  [27] blob_1.2.1                  rvest_0.3.5                
##  [29] haven_2.2.0                 xfun_0.12                  
##  [31] crayon_1.3.4                RCurl_1.98-1.2             
##  [33] jsonlite_1.7.0              genefilter_1.66.0          
##  [35] zoo_1.8-8                   survival_3.1-12            
##  [37] iterators_1.0.12            glue_1.4.1                 
##  [39] gtable_0.3.0                ipred_0.9-9                
##  [41] zlibbioc_1.30.0             XVector_0.24.0             
##  [43] webshot_0.5.2               DelayedArray_0.10.0        
##  [45] shape_1.4.4                 quantmod_0.4.17            
##  [47] BiocGenerics_0.30.0         abind_1.4-5                
##  [49] scales_1.1.1                DBI_1.1.0                  
##  [51] Rcpp_1.0.4.6                xtable_1.8-4               
##  [53] htmlTable_1.13.3            bit_1.1-15.2               
##  [55] foreign_0.8-76              Formula_1.2-3              
##  [57] stats4_3.6.3                lava_1.6.7                 
##  [59] prodlim_2019.11.13          glmnet_3.0-2               
##  [61] htmlwidgets_1.5.1           httr_1.4.1                 
##  [63] acepack_1.4.1               ellipsis_0.3.1             
##  [65] farver_2.0.3                XML_3.99-0.3               
##  [67] pkgconfig_2.0.3             nnet_7.3-14                
##  [69] dbplyr_1.4.2                locfit_1.5-9.4             
##  [71] labeling_0.3                AnnotationDbi_1.46.1       
##  [73] tidyselect_1.1.0            rlang_0.4.6                
##  [75] reshape2_1.4.4              munsell_0.5.0              
##  [77] cellranger_1.1.0            tools_3.6.3                
##  [79] cli_2.0.2                   RSQLite_2.2.0              
##  [81] generics_0.0.2              broom_0.5.5                
##  [83] evaluate_0.14               yaml_2.2.1                 
##  [85] bit64_0.9-7                 ModelMetrics_1.2.2.2       
##  [87] fs_1.4.0                    zip_2.0.4                  
##  [89] caTools_1.18.0              xml2_1.2.5                 
##  [91] compiler_3.6.3              rstudioapi_0.11            
##  [93] png_0.1-7                   curl_4.3                   
##  [95] ggsignif_0.6.0              reprex_0.3.0               
##  [97] geneplotter_1.62.0          stringi_1.4.6              
##  [99] highr_0.8                   Matrix_1.2-18              
## [101] vctrs_0.3.1                 pillar_1.4.4               
## [103] lifecycle_0.2.0             data.table_1.12.8          
## [105] bitops_1.0-6                GenomicRanges_1.36.1       
## [107] R6_2.4.1                    latticeExtra_0.6-29        
## [109] KernSmooth_2.23-17          rio_0.5.16                 
## [111] IRanges_2.18.3              codetools_0.2-16           
## [113] MASS_7.3-51.6               gtools_3.8.2               
## [115] assertthat_0.2.1            SummarizedExperiment_1.14.1
## [117] DESeq2_1.24.0               withr_2.2.0                
## [119] S4Vectors_0.22.1            GenomeInfoDbData_1.2.1     
## [121] parallel_3.6.3              hms_0.5.3                  
## [123] rpart_4.1-15                timeDate_3043.102          
## [125] class_7.3-17                rmarkdown_2.1              
## [127] TTR_0.23-6                  pROC_1.16.2                
## [129] Biobase_2.44.0              lubridate_1.7.4            
## [131] base64enc_0.1-3